import itertools
import logging
import os
import re
import subprocess
import tempfile
from datetime import datetime, timedelta

import matplotlib
import matplotlib.pyplot as plt

# Render off-screen: plots are only ever saved to files.
matplotlib.use('agg')
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8*"
# and 3.8 removed the old names, so prefer the new name when it is available
# and fall back to the legacy one on older releases.
plt.style.use(
    'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')


def avg(x: list[float]):
    """Return the arithmetic mean of the numbers in ``x``

    An empty list results in a ZeroDivisionError, since the sum is divided
    by the number of elements.
    """
    total = sum(x)
    count = len(x)
    return total / count


def check_pid(pid):
    """Check if a process with a given PID is running

    Signal 0 is used, which performs the existence and permission checks
    without actually delivering a signal to the target process.

    Parameters
    ----------
    pid : int
        Process ID to probe.

    Returns
    -------
    bool
        True if the process exists, False otherwise.

    """
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but belongs to another user, so we are
        # just not allowed to signal it.
        return True
    except OSError:
        # ESRCH (no such process) or any other failure to probe the PID
        return False
    return True


class CpuTop:
    """Monitor the CPU utilization of a process using the top command

    The top command is executed in batch mode with per-thread reporting
    (``top -H -b -p <pid>``). Its raw output is saved to a temporary file,
    which is parsed on demand to compute per-thread statistics and plots.
    """

    # Matches the "[<pid>]" suffix appended to thread names by parse_results
    _PID_SUFFIX = re.compile(r'\[\d+\]$')

    def __init__(self, proc: subprocess.Popen, proc_name: str) -> None:
        """Construct the CpuTop object

        Parameters
        ----------
        proc : subprocess.Popen
            Process to be monitored.

        proc_name : str
            Name of the parent process.

        """
        self.proc = proc
        self.proc_name = proc_name
        # The top subprocess is only created by run(). Define the attribute
        # upfront so that the cleanup is safe even if run() is never called.
        self.top_ps = None
        self.fp = tempfile.TemporaryFile()

    @staticmethod
    def _strip_pid(thread_name: str) -> str:
        """Remove the trailing "[<pid>]" suffix from a thread name"""
        return CpuTop._PID_SUFFIX.sub('', thread_name)

    def _stop_top(self) -> None:
        """Terminate the top subprocess (if any) and release its pipe"""
        top_ps = getattr(self, "top_ps", None)
        if top_ps is None:
            return
        if top_ps.poll() is None:
            top_ps.kill()
        if top_ps.stdout is not None:
            top_ps.stdout.close()
        # Reap the child so that it does not linger as a zombie
        top_ps.wait()

    def parse_results(self) -> dict[str, tuple[list[datetime], list[float]]]:
        """Parse the CPU usage from the results generated by top

        Returns
        -------
        dict
            Dictionary keyed by "<thread name>[<thread pid>]". Each value is
            a tuple with two lists of the same length: the timestamps of the
            samples and the corresponding CPU utilization values (in %).
            The timestamps only carry the time of the day printed by top, so
            their date is a placeholder. Only differences between them are
            meaningful.

        """
        cpu_usage = {}
        collecting = False
        last_timestamp = None
        one_day = timedelta(days=1)
        day_offset = timedelta(0)
        self.fp.seek(0)
        for line in self.fp:
            # Thread names are arbitrary bytes, so do not fail on the ones
            # that are not valid UTF-8
            res = line.decode(errors='replace').split()
            if not res:
                continue

            # The CPU utilization lines start after the table header line,
            # which contains the metric names, including "COMMAND" (the
            # subprocess/thread command name)
            if res[-1] == "COMMAND":
                collecting = True
                continue

            # Each batch-mode print starts with the "top" word, followed by
            # the current time of the day
            if res[0] == "top":
                timestamp = datetime.strptime(res[2],
                                              '%H:%M:%S') + day_offset
                # top only prints the time of the day. Hence, a large
                # backward jump means the measurement crossed midnight.
                if (last_timestamp is not None
                        and last_timestamp - timestamp > timedelta(hours=12)):
                    day_offset += one_day
                    timestamp += one_day
                last_timestamp = timestamp
                collecting = False
                continue

            if not collecting:
                continue

            # Skip non-result lines
            if len(res) < 12:
                continue

            # Skip threads with the same name as the parent process and a
            # different pid
            thread_pid = int(res[0])
            thread = res[11]
            if thread == self.proc_name and thread_pid != self.proc.pid:
                continue

            # Save the CPU utilization on lists of results for each
            # subprocess/thread. Append the thread PID to the name to
            # distinguish between threads with the same name. For instance,
            # this is helpful when there are repeated blocks in the GR
            # flowgraph whose truncated names coincide on top (given top
            # truncates thread names to 15 characters).
            thread_name = "{}[{}]".format(thread, thread_pid)
            if thread_name not in cpu_usage:
                cpu_usage[thread_name] = (list(), list())

            assert last_timestamp is not None
            cpu_usage[thread_name][0].append(last_timestamp)
            cpu_usage[thread_name][1].append(float(res[8]))

        return cpu_usage

    def run(self):
        """Run top and save results to a temporary file

        Blocks until either the monitored process or top terminates. The top
        subprocess is always stopped before returning.
        """
        self.top_ps = subprocess.Popen(
            ['top', '-H', '-b', '-p',
             str(self.proc.pid)],
            stdout=subprocess.PIPE)

        try:
            while (self.top_ps.poll() is None
                   and self.proc.poll() is None):
                line = self.top_ps.stdout.readline()
                if not line:
                    # EOF: top closed its output, nothing else to read
                    break
                self.fp.write(line)
        finally:
            self._stop_top()

    def get_avg_results(self):
        """Parse results and compute the average CPU usage of all threads

        Returns
        -------
        dict
            Average CPU utilization (in %) keyed by
            "<thread name>[<thread pid>]".

        """
        cpu_usage = self.parse_results()
        return {k: avg(v[1]) for k, v in cpu_usage.items()}

    def get_avg_results_str(self):
        """Return a string with the average CPU usage of all threads

        The string has one line per thread, sorted by the average CPU
        utilization in descending order.
        """
        avg_results = self.get_avg_results()
        ranked = sorted(avg_results.items(),
                        key=lambda item: item[1],
                        reverse=True)

        return '\n'.join(
            '{:15s} : {:.2f}%'.format(
                self._strip_pid(k),  # remove pid from thread name
                v) for k, v in ranked)

    def plot(self, name: str):
        """Plot the CPU utilization of each thread over time

        The plot is saved as "plots/<name>_<pid>.png", where <pid> is the
        PID of the monitored process.

        Parameters
        ----------
        name : str
            Target plot save name.

        """
        cpu_usage = self.parse_results()

        fig, ax = plt.subplots()
        marker = itertools.cycle(('h', 'p', '^', 'v', ',', '+', '.', 'o', '*'))
        linestyle = itertools.cycle(('-', '--', '-.', ':'))

        # Sort curves by the average CPU utilization in descending order
        for key in sorted(cpu_usage,
                          key=lambda key: avg(cpu_usage[key][1]),
                          reverse=True):
            timestamps, values = cpu_usage[key]
            # Time elapsed since the first sample of this thread
            time_axis = [(t - timestamps[0]).total_seconds()
                         for t in timestamps]
            ax.plot(time_axis,
                    values,
                    label=self._strip_pid(key),
                    marker=next(marker),
                    linestyle=next(linestyle))

        # Place the legend outside the axes, on the right-hand side
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax.set_ylabel("CPU Usage (%)")
        ax.set_xlabel("Seconds")
        ax.grid(True)
        fig.tight_layout()

        # Save into a directory called plots/
        res_dir = "plots"
        os.makedirs(res_dir, exist_ok=True)
        savepath = os.path.join(res_dir, f"{name}_{self.proc.pid}.png")
        logging.info("Saving CPU plot to %s", savepath)
        fig.savefig(savepath, dpi=300)
        plt.close(fig)

    def __del__(self):
        # Use guarded lookups, as the object may be partially constructed or
        # run() may never have been called
        self._stop_top()
        fp = getattr(self, "fp", None)
        if fp is not None:
            fp.close()
